import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.distributions import Categorical
from torch_geometric.nn import GATv2Conv
from torch_geometric.loader import DataLoader


# UPMSP_Model
class UPMSP_Model(nn.Module):
    """
    Graph-neural-network scheduling policy.

    The class name refers to UPMSP (presumably the Unrelated Parallel Machine
    Scheduling Problem -- TODO confirm). Scoring is delegated to ``GNNModel``,
    which emits one logit per graph node; this wrapper masks those logits and
    turns them into per-graph action distributions.
    """
    def __init__(self, model_params):
        super().__init__()
        self.model_params = model_params
        self.gnn_model = GNNModel(self.model_params)

    def forward(self, state, latent_var, len_state):
        """
        Compute masked per-node logits for a batch of graphs.

        Args:
            state: Batched graph exposing ``x``, ``edge_index_1``,
                ``edge_1_feature``, ``edge_index_2``, ``edge_index_3``,
                ``edge_3_feature`` and ``mask``.
            latent_var: 2-D latent tensor, presumably one row per graph in
                the batch, that conditions the node scores.
            len_state: Number of graphs in the batch.

        Returns:
            A 1-D tensor with one logit per node. ``log(state.mask)`` is
            added, so nodes whose mask is 0 get ``-inf`` (assumes the mask
            holds 0/1 floats -- TODO confirm).
        """
        logit = self.gnn_model(state.x, state.edge_index_1, state.edge_1_feature,
                               state.edge_index_2, state.edge_index_3, state.edge_3_feature,
                               latent_var, len_state)
        return logit.reshape(-1) + torch.log(state.mask)

    def _device(self):
        """Return the device holding the model parameters (CPU if there are none)."""
        param = next(self.parameters(), None)
        return param.device if param is not None else torch.device('cpu')

    def _batched_logits(self, state, latent_var):
        """Collate a list of graphs and return masked logits shaped (len(state), -1).

        Raises:
            ValueError: If ``state`` is empty.
        """
        num_graphs = len(state)
        if num_graphs == 0:
            raise ValueError("state must contain at least one graph")
        # A DataLoader whose batch size equals the list length yields the
        # whole list as a single collated batch.
        batch = next(iter(DataLoader(state, batch_size=num_graphs)))
        # Follow the model's device instead of assuming CUDA.
        batch = batch.to(self._device())
        return self.forward(batch, latent_var, num_graphs).reshape(num_graphs, -1)

    def get_policy(self, state, latent_var=None, len_state=None):
        """
        Build the action distribution for an already-batched state.

        Args:
            state: Batched graph accepted by ``forward``.
            latent_var: Latent tensor accepted by ``forward`` (required).
            len_state: Number of graphs in the batch. Defaults to
                ``latent_var.size(0)``, because ``GNNModel`` pairs one latent
                row with each graph.

        Returns:
            A ``Categorical`` over each graph's actions, batch shape (len_state,).

        Raises:
            TypeError: If ``latent_var`` is not given.
        """
        if latent_var is None:
            raise TypeError("get_policy() requires latent_var")
        if len_state is None:
            len_state = latent_var.size(0)
        logits = self.forward(state, latent_var, len_state).reshape(len_state, -1)
        # forward() returns logits (possibly -inf for masked nodes), not probabilities.
        return Categorical(logits=logits)

    def get_action(self, state, latent_var):
        """
        Sample one action per graph from the current policy.

        Args:
            state: Non-empty list of graph data objects, collated into one batch.
            latent_var: Latent tensor with one row per graph in ``state``.

        Returns:
            ``(action, log_prob)``: sampled node index per graph and its log-probability.
        """
        policy = Categorical(logits=self._batched_logits(state, latent_var))
        action = policy.sample()
        return action, policy.log_prob(action)

    def get_max_action(self, state, latent_var=None):
        """
        Greedily pick the highest-scoring action for each graph.

        Args:
            state: Non-empty list of graph data objects, collated into one batch.
            latent_var: Latent tensor with one row per graph in ``state`` (required).

        Returns:
            Tensor of arg-max node indices, one per graph.

        Raises:
            TypeError: If ``latent_var`` is not given.
        """
        if latent_var is None:
            raise TypeError("get_max_action() requires latent_var")
        logits = self._batched_logits(state, latent_var)
        # softmax is monotonic, so the arg-max of the logits equals that of the probabilities.
        return torch.argmax(logits, -1)

class GNNModel(nn.Module):
    """
    Node-scoring graph network built from GATv2Conv layers.

    Encoder: a Linear+ReLU node embedding followed by ``encoder_layer_num``
    residual layers. Each layer runs four GATv2Conv convolutions (edge set 1,
    edge set 1 reversed, edge set 2, edge set 3), concatenates their outputs,
    projects them back to ``embedding_dim``, adds the layer input and applies
    a ReLU.

    Decoder: a two-layer MLP mapping each node embedding, concatenated with
    the latent vector of its graph, to a single score.

    Args:
        model_params: Mapping with keys ``input_dim``, ``embedding_dim``,
            ``head_num``, ``encoder_layer_num``, ``latent_cont_dim`` and
            ``latent_disc_dim``.

    Raises:
        ValueError: If ``embedding_dim`` is not divisible by ``head_num``.
    """
    def __init__(self, model_params):
        super().__init__()
        self.model_params = model_params
        self.input_dim = self.model_params["input_dim"]
        self.embedding_dim = self.model_params["embedding_dim"]
        self.head = self.model_params["head_num"]
        self.encoder_layer_num = self.model_params['encoder_layer_num']
        self.lc_dim = self.model_params['latent_cont_dim']
        self.ld_dim = self.model_params['latent_disc_dim']
        self.l_dim = self.lc_dim + self.ld_dim

        # With GATv2Conv's default concat=True, each conv outputs
        # heads * out_channels features. That equals embedding_dim, as
        # concat_layers and the residual connection require, only when the
        # division is exact.
        if self.embedding_dim % self.head != 0:
            raise ValueError(
                f"embedding_dim ({self.embedding_dim}) must be divisible by "
                f"head_num ({self.head})"
            )
        head_dim = self.embedding_dim // self.head

        self.embedding = nn.Sequential(nn.Linear(self.input_dim, self.embedding_dim), nn.ReLU())

        # Construction order is kept as before, so parameter initialisation
        # and state_dict layout are unchanged.
        self.sijk_layers = self._gat_stack(head_dim, edge_dim=1)
        self.sijk_backward_layers = self._gat_stack(head_dim, edge_dim=1)
        self.same_job_layers = self._gat_stack(head_dim)
        self.same_machine_layers = self._gat_stack(head_dim, edge_dim=1)

        self.concat_layers = nn.Linear(4 * self.embedding_dim, self.embedding_dim)
        self.decoding_layers = nn.Sequential(nn.Linear(self.embedding_dim + self.l_dim, self.embedding_dim), nn.ReLU(),
                                             nn.Linear(self.embedding_dim, 1))
        self.relu = nn.ReLU()

    def _gat_stack(self, head_dim, edge_dim=None):
        """Build one embedding_dim -> embedding_dim GATv2Conv per encoder layer."""
        return nn.ModuleList([GATv2Conv(self.embedding_dim, head_dim, edge_dim=edge_dim,
                                        heads=self.head, dropout=0.2)
                              for _ in range(self.encoder_layer_num)])

    def forward(self, node_feature, edge_index_1, edge_feature_1, edge_index_2, edge_index_3, edge_feature_3, latent_var, len_state):
        """
        Score every node of a batch of graphs.

        Args:
            node_feature: (num_nodes, input_dim) features of all nodes in the batch.
            edge_index_1: Edge index of edge set 1. It is also used reversed
                (rows swapped), assuming the PyG (2, num_edges) layout.
            edge_feature_1: Edge features for edge set 1 (edge_dim=1), shared
                by both directions.
            edge_index_2: Edge index of edge set 2 (no edge features).
            edge_index_3: Edge index of edge set 3.
            edge_feature_3: Edge features for edge set 3 (edge_dim=1).
            latent_var: 2-D (batch_size, latent_dim) tensor; row i conditions
                the i-th contiguous block of nodes.
            len_state: Number of graphs. Every graph is assumed to contribute
                ``num_nodes // len_state`` consecutive nodes.

        Returns:
            (num_nodes, 1) tensor of node scores.

        Raises:
            ValueError: If the latent rows cannot be matched one-to-one with
                equal-sized blocks of nodes.
        """
        num_nodes = node_feature.size(0)
        batch_size = latent_var.size(0)
        nodes_per_graph = num_nodes // len_state
        if batch_size * nodes_per_graph != num_nodes:
            raise ValueError(
                f"cannot broadcast {batch_size} latent rows over {num_nodes} nodes "
                f"({nodes_per_graph} nodes per graph for len_state={len_state})"
            )
        # flip() already returns a new tensor, so no clone is needed.
        edge_index_1_backward = edge_index_1.flip(dims=[0])

        embedding = self.embedding(node_feature)
        for forward_conv, backward_conv, job_conv, machine_conv in zip(
                self.sijk_layers, self.sijk_backward_layers,
                self.same_job_layers, self.same_machine_layers):
            messages = torch.cat([
                forward_conv(embedding, edge_index_1, edge_feature_1),
                backward_conv(embedding, edge_index_1_backward, edge_feature_1),
                job_conv(embedding, edge_index_2),
                machine_conv(embedding, edge_index_3, edge_feature_3),
            ], dim=1)
            # Residual update followed by ReLU.
            embedding = self.relu(embedding + self.concat_layers(messages))

        # Repeat each latent row once per node of its graph: row i covers
        # nodes [i * nodes_per_graph, (i + 1) * nodes_per_graph).
        latent_per_node = latent_var.repeat_interleave(nodes_per_graph, dim=0)
        cond_embedding = torch.cat([embedding, latent_per_node], dim=-1)
        return self.decoding_layers(cond_embedding)